#!/usr/bin/env python
# coding: utf-8

# In[2]:


import numpy as np
import pandas as pd
import matplotlib.pylab as plt
import ot
import cvxpy as cp

import seaborn as sns


# ## Functions

# In[3]:


def baryc_proj(source, target, method = 'emd'):
    """Return the barycentric projection of `source` onto `target` under an OT plan.

    An optimal transport plan is computed between the uniform empirical
    measures on the rows of `source` and `target`. Each row of the plan is
    normalised into a conditional distribution over the target points. The
    projection of source[i] is the corresponding conditional expectation of
    the target points.

    Parameters
    ----------
    source : array-like, shape (n1, p)
        Source sample, one observation per row.
    target : array-like, shape (n2, p)
        Target sample, one observation per row.
    method : {'emd', 'entropic'}
        'emd' solves the exact problem with ``ot.emd``. 'entropic' uses
        ``ot.bregman.sinkhorn_stabilized`` with ``reg=5e-3`` on the
        max-normalised cost matrix.

    Returns
    -------
    numpy.ndarray, shape (n1, p), dtype float32
        Row i is the barycentric projection of source[i].

    Raises
    ------
    ValueError
        If `method` is not 'emd' or 'entropic'.
    """
    # Validate first: an unknown method previously fell through and left the
    # transport plan unbound, and it avoids computing the cost matrix for nothing.
    if method not in ('emd', 'entropic'):
        raise ValueError(f"unknown method {method!r}; expected 'emd' or 'entropic'")

    source = np.asarray(source)
    target = np.asarray(target)
    n1 = source.shape[0]
    n2 = target.shape[0]
    a_ones, b_ones = np.ones((n1,)) / n1, np.ones((n2,)) / n2

    # Pairwise cost matrix (ot.dist's default metric), rescaled to [0, 1],
    # presumably so the fixed entropic regularisation does not depend on the
    # scale of the data.
    M = ot.dist(source, target).astype('float64')
    M_max = M.max()
    if M_max > 0:  # all points identical -> every cost is 0; avoid 0/0 = NaN
        M /= M_max

    if method == 'emd':
        OTplan = ot.emd(a_ones, b_ones, M, numItermax = 10_000_000)
    else:
        OTplan = ot.bregman.sinkhorn_stabilized(a_ones, b_ones, M, reg = 5e-3)

    # Normalise each row of the plan into a conditional distribution, then take
    # the conditional expectation of the target points for every source point
    # at once (replaces the per-row loop with repeated np.vstack).
    conditional = OTplan / OTplan.sum(axis = 1, keepdims = True)
    OTmap = conditional @ target

    return OTmap.astype('float32')


# In[4]:


def DSCreplication(target, controls, method = 'emd', projtype = 'wass'):
    """Replicate `target` as a weighted combination of `controls` in OT space.

    Each control is mapped onto the target's support by the barycentric
    projection ``baryc_proj(target, control, method)``. The weights lambda
    minimise the squared Frobenius norm of ``sum_j lambda_j * (G_j - target)``,
    where G_j is the projection onto control j. The weights are constrained to
    the probability simplex.

    Parameters
    ----------
    target : numpy.ndarray, shape (n, d)
        Target sample.
    controls : sequence of array-like
        Control samples; each must have d columns.
    method : str
        OT solver forwarded to ``baryc_proj`` ('emd' or 'entropic').
    projtype : {'eucl', 'wass'}
        'eucl' returns the weighted average of the projections.
        'wass' returns a free-support Wasserstein barycenter of the
        projections with the optimal weights, initialised at the 'eucl'
        solution.

    Returns
    -------
    (weights, projection)
        weights : numpy.ndarray of shape (len(controls),).
        projection : the replicated sample described under `projtype`.

    Raises
    ------
    ValueError
        If `projtype` is unknown or `controls` is empty.
    RuntimeError
        If the weight optimisation returns no solution.
    """
    # Fail fast, before the expensive transport computations.
    if projtype not in ('eucl', 'wass'):
        raise ValueError(f"unknown projtype {projtype!r}; expected 'eucl' or 'wass'")
    if len(controls) == 0:
        raise ValueError('controls must contain at least one sample')

    n = target.shape[0]
    d = target.shape[1]
    J = len(controls)
    # Stabilizer: a positive rescaling of the objective, which leaves its
    # minimiser unchanged. abs() keeps it positive for negative-mean targets;
    # a negative S would make the objective concave and cvxpy would reject the
    # problem as non-DCP. Fall back to 1 when the mean is exactly zero.
    S = abs(np.mean(target)) * n * d * J
    if S == 0:
        S = 1.0

    # Barycentric projection of the target's support onto each control.
    G_list = [baryc_proj(target, control, method) for control in controls]
    proj_list = [G - target for G in G_list]

    # Optimal weights on the probability simplex.
    mylambda = cp.Variable(J)
    residual = cp.sum([lam * P for lam, P in zip(mylambda, proj_list)], axis = 0)
    objective = cp.Minimize(cp.sum_squares(residual) / S)
    constraints = [mylambda >= 0, mylambda <= 1, cp.sum(mylambda) == 1]
    prob = cp.Problem(objective, constraints)
    prob.solve()

    weights = mylambda.value
    if weights is None:
        raise RuntimeError(f'weight optimisation failed (status: {prob.status})')

    testproj = sum(w * G for w, G in zip(weights, G_list))
    measureweights = [ot.unif(n)] * J
    print('optimized')

    if projtype == 'eucl':
        return (weights, testproj)

    projection = ot.lp.free_support_barycenter(G_list, measureweights, X_init = testproj,
                                               weights = weights, numItermax = 10_000_000)
    return (weights, projection)


# ## Mixed Multivariate Normal

# In[5]:


def mixed_multi_gauss(mean1, mean2, mean3, cov1, cov2, cov3, samplesize, partition1, partition2):
    """Draw a shuffled sample from a three-component multivariate normal mixture.

    Component k contributes a fixed number of draws from
    N(mean_k, cov_k) using the global ``np.random`` state. Component 1 gets
    about ``samplesize * partition1`` rows, component 2 about
    ``samplesize * partition2`` rows, and component 3 the remainder.

    Parameters
    ----------
    mean1, mean2, mean3 : array-like
        Component means.
    cov1, cov2, cov3 : array-like
        Component covariance matrices.
    samplesize : int
        Total number of rows to draw.
    partition1, partition2 : float
        Shares of the first and second components. Each must be >= 0 and
        their sum must be <= 1.

    Returns
    -------
    numpy.ndarray
        The concatenated draws, with rows shuffled in place.

    Raises
    ------
    ValueError
        If the partitions are negative or sum to more than 1.
    """
    if partition1 < 0 or partition2 < 0 or partition1 + partition2 > 1 + 1e-12:
        raise ValueError('partition1 and partition2 must be non-negative with sum <= 1')

    total = int(samplesize)
    # Round cumulative cut points instead of truncating each share. int() turns
    # float noise such as 100 * 0.29 == 28.999... into 28. Rounding the
    # cumulative boundaries keeps 0 <= size_k and size1 + size2 + size3 == total.
    cut1 = int(round(total * partition1))
    cut2 = min(int(round(total * (partition1 + partition2))), total)
    size1 = cut1
    size2 = cut2 - cut1
    size3 = total - cut2

    gauss1 = np.random.multivariate_normal(mean = mean1, cov = cov1, size = size1)
    gauss2 = np.random.multivariate_normal(mean = mean2, cov = cov2, size = size2)
    gauss3 = np.random.multivariate_normal(mean = mean3, cov = cov3, size = size3)

    mixed = np.concatenate((gauss1, gauss2, gauss3), axis = 0)
    np.random.shuffle(mixed)

    return(mixed)

def multi_gauss_mix(gauss_set, partition1 = None, partition2 = None):
    """Concatenate pre-drawn samples row-wise and shuffle the rows.

    Parameters
    ----------
    gauss_set : sequence of numpy.ndarray
        Samples to pool; all must have the same number of columns.
    partition1, partition2 : optional
        Ignored. They are kept (now optional) only so existing positional
        calls keep working; the component sizes are whatever the arrays in
        `gauss_set` already contain.

    Returns
    -------
    numpy.ndarray
        A new array with all rows of `gauss_set`, shuffled in place using the
        global ``np.random`` state. The input arrays are not modified.
    """
    mixed = np.concatenate(gauss_set, axis = 0)
    np.random.shuffle(mixed)

    return(mixed)


# In[6]:


# Fix the global RNG seed so every draw below is reproducible. The statement
# order matters: each call consumes the same global random stream.
np.random.seed(31)

dim = 10    # dimension of each observation
obs = 1000  # number of observations per sample

# Component means: constant vectors at well-separated locations.
mu1 = [10]*dim
mu2 = [50]*dim
mu3 = [200]*dim
mu4 = [-50]*dim
mu5 = [-100]*dim  # NOTE(review): not used anywhere in the visible file

# Equicorrelated covariance: unit variances, correlation 0.5 between every pair.
covmat = np.full((dim, dim), 0.5)
np.fill_diagonal(covmat, 1)

# Single-Gaussian samples; X1 is later used as the target, X2..X4 as controls.
X1 = np.random.multivariate_normal(mean = mu1, cov = covmat, size = obs)
X2 = np.random.multivariate_normal(mean = mu2, cov = covmat, size = obs)
X3 = np.random.multivariate_normal(mean = mu3, cov = covmat, size = obs)
X4 = np.random.multivariate_normal(mean = mu4, cov = covmat, size = obs)

# Three-component mixtures. The last two arguments are the shares of the first
# and second listed means; the third listed mean receives the remainder.
# Y1 is later used as the target, Y2..Y4 as controls.
Y1 = mixed_multi_gauss(mu1, mu2, mu3, covmat, covmat, covmat, obs, 0.7, 0.15)  # 70% mu1, 15% mu2, 15% mu3
Y2 = mixed_multi_gauss(mu1, mu2, mu3, covmat, covmat, covmat, obs, 0.6, 0.3)   # 60% mu1, 30% mu2, 10% mu3
Y3 = mixed_multi_gauss(mu2, mu3, mu4, covmat, covmat, covmat, obs, 0.7, 0.2)   # 70% mu2, 20% mu3, 10% mu4
Y4 = mixed_multi_gauss(mu1, mu3, mu4, covmat, covmat, covmat, obs, 0.3, 0.1)   # 30% mu1, 10% mu3, 60% mu4


# # X: single-Gaussian target (X1) and controls (X2-X4)

# In[7]:


# Fit the replication weights for target X1 from the controls X2, X3, X4.
weightsx, replicationx = DSCreplication(X1, [X2, X3, X4])
# A bare expression only displays inside a notebook; print so the script shows it too.
print(weightsx)


# In[11]:


# Compare the KDE of per-observation coordinate means of the target with that
# of its replication. A dedicated figure keeps these curves from being drawn
# onto the same axes as later plots when the file runs as a plain script.
fig_x, ax_x = plt.subplots()
sns.kdeplot(np.mean(X1, axis = 1), color = 'royalblue', alpha = 1, label = 'target (X1)', ax = ax_x)
sns.kdeplot(np.mean(replicationx, axis = 1), color = 'orange', alpha = 0.5, label = 'replication', ax = ax_x)
ax_x.legend()

#plt.savefig('Gauss.png', transparent = True)


# # Y: Gaussian-mixture target (Y1) and controls (Y2-Y4)

# In[9]:


# Fit the replication weights for target Y1 from the controls Y2, Y3, Y4.
weightsy, replicationy = DSCreplication(Y1, [Y2, Y3, Y4], projtype = 'wass')
# A bare expression only displays inside a notebook; print so the script shows it too.
print(weightsy)


# In[10]:


# Compare the KDE of per-observation coordinate means of the target with that
# of its replication, on a dedicated figure so this plot does not overlay the
# X plot when the file runs as a plain script.
fig_y, ax_y = plt.subplots()
sns.kdeplot(np.mean(Y1, axis=1), color = 'royalblue', label = 'target (Y1)', ax = ax_y)
sns.kdeplot(np.mean(replicationy, axis=1), color = 'darkorange', alpha = 0.5, label = 'replication', ax = ax_y)
ax_y.legend()

#plt.savefig('GaussMulti.png', transparent = True)

